{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understanding Annotation Consolidation: A SageMaker Ground Truth Demonstration for Image Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "This sample notebook demonstrates how Ground Truth annotation consolidation works and shows the performance improvement that the built-in annotation consolidation algorithm for image classification achieves over alternative approaches.\n",
    "\n",
    "We start with the output of a labeling job in which 5 workers annotated 302 images of birds from the [Google Open Images Dataset](https://storage.googleapis.com/openimages/web/index.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Disclosure regarding the Open Images Dataset V4\n",
    "\n",
    "Open Images Dataset V4 is created by Google Inc. In some cases we have modified the images or the accompanying annotations. You can obtain the original images and annotations from the [Open Images Dataset website](https://storage.googleapis.com/openimages/web/index.html). The annotations are licensed by Google Inc. under the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license. The images are listed as having a [CC BY 2.0](https://creativecommons.org/licenses/by/2.0/) license. The following paper describes Open Images V4 in depth, from the data collection and annotation to detailed statistics about the data and evaluation of models trained on it.\n",
    "\n",
    "A. Kuznetsova, H. Rom, N. Alldrin, J. Uijlings, I. Krasin, J. Pont-Tuset, S. Kamali, S. Popov, M. Malloci, T. Duerig, and V. Ferrari. The Open Images Dataset V4: Unified image classification, object detection, and visual relationship detection at scale. arXiv:1811.00982, 2018. [(pdf)](https://arxiv.org/abs/1811.00982)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "from collections import defaultdict\n",
    "import itertools\n",
    "import json\n",
    "import random\n",
    "import subprocess\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import NullFormatter\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import boto3\n",
    "import glob\n",
    "from scipy import stats\n",
    "import pandas as pd \n",
    "from io import BytesIO\n",
    "from PIL import Image\n",
    "import pickle\n",
    "\n",
    "# Set up the files you will need for the analysis.\n",
    "subprocess.run(['tar', '-xvf', 'ic_data.tar.gz'], check=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze Ground Truth labeling job results\n",
    "\n",
    "We are now ready to analyze the labeling job results, all of which are contained in the output manifest. If you ran your own job, you can find the location of the output manifest under AWS Console > SageMaker > Labeling Jobs > [name of your job]. For our pre-completed job, the output manifest is included in the `ic_data` archive extracted above and is loaded in the cells below.\n",
    "\n",
    "### Postprocess the output manifest\n",
    "\n",
    "First, load the output manifest and postprocess it to build:\n",
    "\n",
    "* `img_uris` -- list containing the S3 URIs of all the images that Ground Truth annotated.\n",
    "* `true_labels` -- dictionary mapping each image's index in the manifest to its true label category.\n",
    "* `jobname` -- the labeling job name, recovered from the manifest's metadata key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
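    "def find_true_label(src_ref, src_to_label):\n",
    "    # The image ID is the S3 object's file name, without the bucket prefix or the .jpg extension.\n",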
    "    uid = src_ref.replace('s3://birdstop-ic-us-west-2/birds-open-images-subset/all_data/', '').replace('.jpg', '')\n",
    "    return src_to_label[uid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('ic_data/output.manifest', 'r') as f:\n",
    "    output = [json.loads(line.strip()) for line in f.readlines()]\n",
    "\n",
    "# Create and initialize data arrays.\n",
    "img_uris = [None] * len(output)\n",
    "\n",
    "# Find the job name the manifest corresponds to.\n",
    "keys = list(output[0].keys())\n",
    "metakey = [key for key in keys if ('metadata' in key)][0]\n",
    "jobname = metakey.replace('-metadata', '')\n",
    "\n",
    "# Load the true class labels.\n",
    "categories = ['Duck', 'Sparrow', 'Woodpecker', 'Owl', 'Parrot', 'Turkey', 'Falcon', 'Swan', 'Goose', 'Canary']\n",
    "with open('ic_data/true_labels.json', 'r') as input_file:    \n",
    "    true_labels_file = json.load(input_file)\n",
    "    \n",
    "src_to_label = {}\n",
    "for true_label in true_labels_file:\n",
    "    src_to_label[true_label['ImageId']] = true_label['Label']\n",
    "    \n",
    "true_labels = {}\n",
    "# Extract the data.\n",
    "for datum_id, datum in enumerate(output):\n",
    "    true_labels[datum_id] = find_true_label(datum['source-ref'], src_to_label)\n",
    "    img_uris[datum_id] = datum['source-ref']"
   ]
  },
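  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal illustration of the parsing above, the sketch below applies the same two steps to a single made-up manifest line: it recovers the job name from the key that contains `metadata`, and it strips the S3 prefix and the `.jpg` extension to get the image ID. The image ID (`abc123`), the job name (`my-bird-job`), and the metadata values are hypothetical; only the key layout mirrors what the cell above reads from `ic_data/output.manifest`. The `demo_` names avoid overwriting `jobname` and the other variables used later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# A hypothetical output-manifest line; the ID, job name and metadata values are made up.\n",
    "demo_line = json.dumps({\n",
    "    'source-ref': 's3://birdstop-ic-us-west-2/birds-open-images-subset/all_data/abc123.jpg',\n",
    "    'my-bird-job': 3,\n",
    "    'my-bird-job-metadata': {'class-name': 'Owl', 'confidence': 0.93},\n",
    "})\n",
    "\n",
    "demo_datum = json.loads(demo_line)\n",
    "# The job name comes from the one key that contains 'metadata'.\n",
    "demo_metakey = [key for key in demo_datum if 'metadata' in key][0]\n",
    "demo_jobname = demo_metakey.replace('-metadata', '')\n",
    "# The image ID is the file name without the S3 prefix and the .jpg extension.\n",
    "demo_prefix = 's3://birdstop-ic-us-west-2/birds-open-images-subset/all_data/'\n",
    "demo_uid = demo_datum['source-ref'].replace(demo_prefix, '').replace('.jpg', '')\n",
    "print(demo_jobname, demo_uid)  # my-bird-job abc123"
   ]
  },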
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot histogram of true class labels\n",
    "\n",
    "As a preliminary analysis, we look at the class distribution. Our dataset is quite imbalanced and has categories that can be confused with each other. For example, annotators may confuse sparrows, canaries, and parrots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_classes, class_counts = np.unique(list(true_labels.values()), return_counts=True)\n",
    "# Use the number of classes actually present, so the bar positions match the counts.\n",
    "n_classes = len(unique_classes)\n",
    "sorted_class_count_idx = np.argsort(class_counts)[::-1]\n",
    "sorted_unique_classes = unique_classes[sorted_class_count_idx]\n",
    "sorted_class_counts = class_counts[sorted_class_count_idx]\n",
    "\n",
    "plt.figure(figsize=(8, 3), facecolor='white', dpi=100)\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.title('Class distribution')\n",
    "plt.bar(range(n_classes), sorted_class_counts, width=0.7)\n",
    "plt.xticks(range(n_classes), sorted_unique_classes, rotation='vertical')\n",
    "plt.ylabel('Class Count')\n",
    "plt.grid(which='both', alpha=0.3);\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.title('Class distribution (log scale)')\n",
    "plt.bar(range(n_classes), sorted_class_counts, width=0.7, log=True)\n",
    "plt.xticks(range(n_classes), sorted_unique_classes, rotation='vertical')\n",
    "plt.grid(which='both', alpha=0.3);"
   ]
  },
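  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put a number on how imbalanced a label set is, one simple summary (our choice for illustration, not something Ground Truth reports) is the ratio of the largest to the smallest class count. The sketch below uses a made-up list of labels in place of `true_labels.values()` and the same `np.unique`/`np.argsort` ordering as the plot above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up labels standing in for true_labels.values().\n",
    "demo_labels = ['Duck'] * 6 + ['Owl'] * 3 + ['Canary']\n",
    "demo_classes, demo_counts = np.unique(demo_labels, return_counts=True)\n",
    "demo_order = np.argsort(demo_counts)[::-1]  # most frequent class first\n",
    "print(demo_classes[demo_order].tolist(), demo_counts[demo_order].tolist())\n",
    "# Ratio of the most to the least frequent class; 1.0 would mean perfectly balanced.\n",
    "demo_ratio = demo_counts.max() / demo_counts.min()\n",
    "print('imbalance ratio:', demo_ratio)"
   ]
  },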
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison of Consolidation Methods\n",
    "We now compare the performance of our consolidation algorithm, Modified Dawid-Skene (MDS), with two standard baselines: Single Worker (SW), which uses one randomly chosen worker's label per image, and Majority Voting (MV), which picks the most common label among the workers. For MDS and MV, we can see how the performance changes as the number of annotators increases."
   ]
  },
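  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key difference from majority voting is that MDS weighs each vote by an estimate of that worker's reliability. As a rough sketch of the expectation step implemented in the next cell (`expectation_step`, equations 2.3 and 2.4 in Dawid-Skene), an item's class posterior is the prior multiplied, for each worker, by that worker's confusion-matrix entry for (true class, annotated class), then normalized to sum to one. The toy setup is an assumption chosen for illustration: three classes, a uniform prior, and two disagreeing workers with symmetric confusion matrices of accuracy 0.9 and 0.6. Majority voting sees a 1-1 tie here, while the posterior favors the more reliable worker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def symmetric_confusion(n_classes, accuracy):\n",
    "    # Rows are true classes, columns are annotated classes; off-diagonal mass is split evenly.\n",
    "    c = np.full((n_classes, n_classes), (1 - accuracy) / (n_classes - 1))\n",
    "    np.fill_diagonal(c, accuracy)\n",
    "    return c\n",
    "\n",
    "def item_posterior(prior, confusions, annotated):\n",
    "    # Multiply the prior by each worker's likelihood column c[:, label], then normalize.\n",
    "    post = prior.copy()\n",
    "    for c, label in zip(confusions, annotated):\n",
    "        post *= c[:, label]\n",
    "    return post / post.sum()\n",
    "\n",
    "demo_prior = np.ones(3) / 3\n",
    "demo_confusions = [symmetric_confusion(3, 0.9), symmetric_confusion(3, 0.6)]\n",
    "# Worker 1 (accuracy 0.9) says class 0; worker 2 (accuracy 0.6) says class 1.\n",
    "demo_post = item_posterior(demo_prior, demo_confusions, annotated=[0, 1])\n",
    "print(demo_post.round(3))  # [0.818 0.136 0.045]"
   ]
  },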
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataLayer(object):\n",
    "    \n",
    "    \"\"\"\n",
    "    This is a simple substitute for the actual data layer class, for use in local testing.\n",
    "    It stores and retrieves pickles from a dictionary in memory.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.worker_params = defaultdict(dict)\n",
    "        self.label_params = defaultdict(dict)\n",
    "\n",
    "    def put_label_information_s3(self, label_data, dataset_object_id, labeling_job_arn):\n",
    "        self.label_params[labeling_job_arn][dataset_object_id] = pickle.dumps(label_data)\n",
    "\n",
    "    def get_label_information_s3(self, dataset_object_id, labeling_job_arn):\n",
    "        label_data = self.label_params.get(labeling_job_arn, {}).get(dataset_object_id, None)\n",
    "        if label_data:\n",
    "            label_data = pickle.loads(label_data)\n",
    "        return label_data\n",
    "\n",
    "    def put_worker_information_s3(self, worker_data, worker_id, labeling_job_arn):\n",
    "        self.worker_params[labeling_job_arn][worker_id] = pickle.dumps(worker_data)\n",
    "\n",
    "    def get_worker_information_s3(self, worker_id, labeling_job_arn):\n",
    "        worker_data = self.worker_params.get(labeling_job_arn, {}).get(worker_id, None)\n",
    "        if worker_data:\n",
    "            worker_data = pickle.loads(worker_data)\n",
    "        return worker_data\n",
    "\n",
    "class MulticlassDawidSkeneEM(object):\n",
    "    \"\"\"\n",
    "    Implements a modified version of the method described in A. P. Dawid and A. M. Skene, 1979, Maximum Likelihood Estimation of\n",
    "    Observer Error-Rates Using the EM Algorithm, Journal of the Royal Statistical Society\n",
    "    Series C (Applied Statistics), Vol. 28, No. 1, pp. 20-28\n",
    "    https://www.jstor.org/stable/2346806\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, labeling_job_arn, output_config=None, role_arn = None, kms_key_id = None, identifier=\"Testing\"):\n",
    "        self.labeling_job_arn = labeling_job_arn\n",
    "        self.dataset_object_ids = set()\n",
    "        self.worker_ids = set()\n",
    "        self.l_ij = defaultdict(dict)  # A dict of dicts to store the annotations in Lij format\n",
    "        self.p_prior = None  # Item priors will be an np.array len n_classes, storing the class marginal distribution\n",
    "        self.max_epochs = 20\n",
    "        self.min_relative_diff = 1E-8\n",
    "        self.identifier = identifier\n",
    "        self.data_layer = DataLayer()\n",
    "\n",
    "    def update(self, annotation_payload, label_categories, label_attribute_name, is_text=False):\n",
    "        \"\"\"\n",
    "        Update the worker and item parameters, based on a new batch of data\n",
    "        :param annotation_payload: The payload of annotations received, which is a list of items, where each item dict\n",
    "        contains a dataset_object_id, and list of annotations. Each annotation is a dict with a worker_id and\n",
    "        information on the annotation content provided by that worker.\n",
    "        :param label_categories: The list of possible categories for the multiclass classification\n",
    "        :param label_attribute_name: The name assigned to this collection of labels\n",
    "        :param is_text: Denotes that input and output are in text (not image) classification format\n",
    "        :return: The updated item parameters\n",
    "        \"\"\"\n",
    "        all_worker_prior = 0.7\n",
    "        p, c_mtrx = self.get_or_initialize_parameters(annotation_payload, label_categories, all_worker_prior)\n",
    "        log_likelihood = None\n",
    "        for epoch in range(self.max_epochs):\n",
    "            p, p_non_normalized = self.expectation_step(self.l_ij, p, c_mtrx, self.n_classes)\n",
    "            c_mtrx, worker_priors = self.maximization_step(self.l_ij, p, self.n_classes,\n",
    "                                                           self.worker_ids, all_worker_prior)\n",
    "            log_likelihood, relative_diff = self.calc_log_likelihood(\n",
    "                self.l_ij, p_non_normalized, log_likelihood\n",
    "            )\n",
    "            if relative_diff is not None and relative_diff < self.min_relative_diff:\n",
    "                self.put_parameters(p, c_mtrx)\n",
    "                responses = self.format_responses(p, label_categories, label_attribute_name, is_text)\n",
    "                return responses\n",
    "\n",
    "            all_worker_prior = sum([worker_priors[j] for j in worker_priors]) / len(worker_priors)\n",
    "\n",
    "        self.put_parameters(p, c_mtrx)\n",
    "        responses = self.format_responses(p, label_categories, label_attribute_name, is_text)\n",
    "        return responses\n",
    "\n",
    "    def get_or_initialize_parameters(self, annotation_payload, label_categories, all_worker_prior):\n",
    "        \"\"\"\n",
    "        Sets the dataset object_ids and worker_ids, gets the item and worker params if they exist, or initializes\n",
    "        them if they do not.\n",
    "        :param annotation_payload: The payload of annotations received, which is a list of items, where each item dict\n",
    "        contains a dataset_object_id, and list of annotations. Each annotation is a dict with a worker_id and\n",
    "        information on the annotation content provided by that worker.\n",
    "        :param label_categories: The list of possible categories for the multiclass classification\n",
    "        :param all_worker_prior: The assumed prior accuracy rate of an average worker\n",
    "        :return: The item parameters (p) and the worker confusion matrices (c_mtrx)\n",
    "        \"\"\"\n",
    "\n",
    "        self.label_categories = label_categories\n",
    "        self.n_classes = len(label_categories)\n",
    "\n",
    "        # Store the dataset object_ids and worker_ids, and store the annotation dataset in Lij form\n",
    "        for item in annotation_payload:\n",
    "            i = item['datasetObjectId']\n",
    "            self.dataset_object_ids.add(i)\n",
    "            for annotation in item['annotations']:\n",
    "                j = annotation['workerId']\n",
    "                self.worker_ids.add(j)\n",
    "                annotation_content = annotation['annotationData']['content']\n",
    "                self.l_ij[i][j] = self.label_categories.index(annotation_content)\n",
    "\n",
    "        # Get or initialize the item parameters\n",
    "        # Item params are a dict of np label-class arrays, keyed by dataset_object_id (i)\n",
    "        p = {}\n",
    "        for i in self.dataset_object_ids:\n",
    "            # item_params = self.data_layer.get_label_information_s3(i, self.labeling_job_arn)\n",
    "            item_params = self.initialize_item_parameters(n_classes=self.n_classes)\n",
    "            p[i] = item_params\n",
    "\n",
    "        # Get or initialize the worker parameters\n",
    "        # Worker params are a dict of np confusion matrices, keyed by worker_id (j)\n",
    "        c_mtrx = {}\n",
    "        for j in self.worker_ids:\n",
    "            # worker_params = self.data_layer.get_worker_information_s3(j, self.labeling_job_arn)\n",
    "            worker_params = self.initialize_worker_params(n_classes=self.n_classes, a=all_worker_prior)\n",
    "            c_mtrx[j] = worker_params\n",
    "\n",
    "        return p, c_mtrx\n",
    "\n",
    "    def put_parameters(self,  p, c_mtrx):\n",
    "        \"\"\"\n",
    "        Write the item and worker parameters back to the data layer\n",
    "        :return: None\n",
    "        \"\"\"\n",
    "        # Write the item parameters back to the data layer.\n",
    "        # DataLayer pickles the objects itself, so pass them unpickled and by keyword.\n",
    "        for i in self.dataset_object_ids:\n",
    "            self.data_layer.put_label_information_s3(p[i], dataset_object_id=i,\n",
    "                                                     labeling_job_arn=self.labeling_job_arn)\n",
    "\n",
    "        # Write the worker parameters back to the data layer\n",
    "        for j in self.worker_ids:\n",
    "            self.data_layer.put_worker_information_s3(c_mtrx[j], worker_id=j,\n",
    "                                                      labeling_job_arn=self.labeling_job_arn)\n",
    "\n",
    "    @staticmethod\n",
    "    def initialize_item_parameters(n_classes):\n",
    "        \"\"\"\n",
    "        Initializes item parameters to an even probability distribution across all classes\n",
    "        :param n_classes: The number of classes\n",
    "        :return: item_parameters\n",
    "        \"\"\"\n",
    "        return np.ones(n_classes) / n_classes\n",
    "\n",
    "    @staticmethod\n",
    "    def initialize_worker_params(n_classes, a=0.7):\n",
    "        \"\"\"\n",
    "        Initializes worker parameters to a confusion matrix with a default accuracy down the diagonal\n",
    "        :param n_classes: The number of classes\n",
    "        :param a: The assumed accuracy of a typical worker, for initializing confusion matrices\n",
    "        :return: worker_params\n",
    "        \"\"\"\n",
    "        worker_params = np.ones((n_classes, n_classes)) * ((1 - a) / (n_classes - 1))\n",
    "        np.fill_diagonal(worker_params, a)\n",
    "        return worker_params\n",
    "\n",
    "    @staticmethod\n",
    "    def expectation_step(l_ij, p, c_mtrx, n_classes):\n",
    "        \"\"\"\n",
    "        The update of the true class probabilities, following equations 2.3 and 2.4 in Dawid-Skene (1979)\n",
    "        :param l_ij: The annotated data, in Lij format\n",
    "        :param p: The current estimate of the true class parameters (dict keyed on dataset_object_id)\n",
    "        :param c_mtrx: The worker confusion matrices (dict keyed on worker_id)\n",
    "        :param n_classes: The number of classes\n",
    "        :return: the updated item params (p), and a non-normalized version of them to use in estimating\n",
    "        the log-likelihood of the data\n",
    "        \"\"\"\n",
    "        # Set our prior value of p to be the marginal class distribution across all items\n",
    "        p_prior = np.zeros(n_classes)\n",
    "        for i in p:\n",
    "            p_prior += p[i]\n",
    "        p_prior /= p_prior.sum()\n",
    "\n",
    "        for i in l_ij:\n",
    "            # Initialize the item params to the prior value\n",
    "            p[i] = p_prior.copy()\n",
    "            for j in l_ij[i]:\n",
    "                annotated_class = l_ij[i][j]\n",
    "                for true_class in range(n_classes):\n",
    "                    error_rate = c_mtrx[j][true_class, annotated_class]\n",
    "                    # Update the prior with likelihoods from the confusion matrix\n",
    "                    p[i][true_class] *= error_rate\n",
    "\n",
    "        # Copy each array, not just the dict, so that normalizing p in place below\n",
    "        # does not also normalize p_non_normalized (needed for the log-likelihood).\n",
    "        p_non_normalized = {i: p[i].copy() for i in p}\n",
    "        # Normalize the item parameters\n",
    "        for i in p:\n",
    "            if p[i].sum() > 0:\n",
    "                p[i] /= float(p[i].sum())\n",
    "        return p, p_non_normalized\n",
    "\n",
    "    def maximization_step(self, l_ij, p, n_classes, worker_ids, all_worker_prior):\n",
    "        \"\"\"\n",
    "        Update of the worker confusion matrices, following equation 2.5 of Dawid-Skene (1979)\n",
    "        :param l_ij: The annotated data, in Lij format\n",
    "        :param p: The current estimate of the true class parameters\n",
    "        :param n_classes: The number of classes\n",
    "        :param worker_ids: the set of worker_ids\n",
    "        :param all_worker_prior: The prior accuracy of an average worker\n",
    "        :return: The updated worker confusion matrices (c_mtrx) and the per-worker accuracy estimates\n",
    "        \"\"\"\n",
    "        # Calculate the updated confusion matrices, based on the new item parameters\n",
    "        all_worker_prior_mtrx = self.initialize_worker_params(n_classes, a=all_worker_prior)\n",
    "\n",
    "        c_mtrx = {}\n",
    "        worker_accuracies = {}\n",
    "        for j in worker_ids:\n",
    "            c_mtrx[j] = np.zeros((n_classes, n_classes))\n",
    "        for i in l_ij:\n",
    "            for j in l_ij[i]:\n",
    "                annotated_class = l_ij[i][j]\n",
    "                for true_class in range(n_classes):\n",
    "                    c_mtrx[j][true_class, annotated_class] += p[i][true_class]\n",
    "\n",
    "        for j in worker_ids:\n",
    "            num_annotations = c_mtrx[j].sum()\n",
    "            worker_accuracies[j] = c_mtrx[j].diagonal().sum() / num_annotations\n",
    "            worker_prior_mtrx = self.initialize_worker_params(n_classes, a=worker_accuracies[j])\n",
    "            # Modification: add pseudo-counts from this worker's own accuracy prior and,\n",
    "            # at half weight, from the average-worker prior before normalizing the rows.\n",
    "            c_mtrx[j] += (worker_prior_mtrx * num_annotations + all_worker_prior_mtrx * num_annotations / 2)\n",
    "\n",
    "            # Perform Dirichlet update to get new confusion matrices\n",
    "            for true_class in range(n_classes):\n",
    "                if c_mtrx[j][true_class].sum() > 0:\n",
    "                    c_mtrx[j][true_class] /= float(c_mtrx[j][true_class].sum())\n",
    "\n",
    "        return c_mtrx, worker_accuracies\n",
    "\n",
    "    @staticmethod\n",
    "    def calc_log_likelihood(l_ij, p_non_normalized, prev_log_likelihood=None):\n",
    "        \"\"\"\n",
    "        Calculate the log-likelihood of the data, so that when it stops improving, we can stop iterating\n",
    "        :param l_ij: The annotated data, in Lij format\n",
    "        :param p_non_normalized: The non-normalized item parameters\n",
    "        :param prev_log_likelihood: The log-likelihood from the previous epoch\n",
    "        :return: the log-likelihood of the data, and the relative difference from the previous epoch's log-likelihood\n",
    "        \"\"\"\n",
    "        log_likelihood = 0.0\n",
    "        relative_diff = None\n",
    "        for i in l_ij:\n",
    "            posterior_i = p_non_normalized[i]\n",
    "            likelihood_i = posterior_i.sum()\n",
    "            log_likelihood += np.log(likelihood_i)\n",
    "\n",
    "        if prev_log_likelihood:\n",
    "            diff = log_likelihood - prev_log_likelihood\n",
    "            # The log-likelihood is negative, so compare the magnitude of the relative change.\n",
    "            relative_diff = abs(diff / prev_log_likelihood)\n",
    "\n",
    "        return log_likelihood, relative_diff\n",
    "\n",
    "    def format_responses(self, params, label_categories, label_attribute_name, is_text):\n",
    "        responses = []\n",
    "        for dataset_object_id in params:\n",
    "            label_estimate = params[dataset_object_id]\n",
    "            confidence_score = round(max(label_estimate), 2)\n",
    "            label, index = self.retrieve_annotation(label_estimate, label_categories)\n",
    "            consolidated_annotation = self.transform_to_label(label, index, label_attribute_name,\n",
    "                                                              confidence_score, is_text)\n",
    "            response = self.build_response(dataset_object_id, consolidated_annotation)\n",
    "            responses.append(response)\n",
    "        return responses\n",
    "\n",
    "    def transform_to_label(self, estimated_label, index, label_attribute_name, confidence_score, is_text):\n",
    "        if is_text:\n",
    "            # transform_to_text_label is not defined in this notebook, which only handles images.\n",
    "            raise NotImplementedError('Text classification output is not supported in this notebook')\n",
    "        return self.transform_to_image_label(estimated_label, index, label_attribute_name, confidence_score)\n",
    "\n",
    "    def transform_to_image_label(self, estimated_label, index, label_attribute_name, confidence_score):\n",
    "        return {\n",
    "            label_attribute_name: int(float(index)),\n",
    "            label_attribute_name + \"-metadata\": {\"class-name\": estimated_label, \"job-name\": self.labeling_job_arn,\n",
    "                                                 \"confidence\": confidence_score, \"type\": \"groundtruth/image-classification\",\n",
    "                                                 \"human-annotated\": \"yes\", \"creation-date\": 'date'},  # 'date' is a placeholder\n",
    "        }\n",
    "\n",
    "    @staticmethod\n",
    "    def retrieve_annotation(label_estimate, label_categories):\n",
    "        elem = label_categories[np.argmax(label_estimate, axis=0)]\n",
    "        index = label_categories.index(elem)\n",
    "        return elem, index\n",
    "\n",
    "    @staticmethod\n",
    "    def build_response(dataset_object_id, consolidated_annotation):\n",
    "        return {\n",
    "            \"datasetObjectId\": dataset_object_id,\n",
    "            \"consolidatedAnnotation\": {\n",
    "                'content': consolidated_annotation\n",
    "            }\n",
    "        }"
   ]
  },
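  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The modification in `maximization_step` is a smoothing step: before each confusion-matrix row is normalized, the worker's soft counts are augmented with pseudo-counts from two symmetric prior matrices, one built from that worker's own estimated accuracy (weighted by the number of annotations) and one from the average worker accuracy (at half that weight). The sketch below reproduces that arithmetic for one hypothetical worker with two classes who labeled four items: three items of class 0 correctly, and one item of class 1 as class 0. The average-worker prior is set to 0.7, the initial value used in `update`. Without smoothing, the class-1 row would be [1, 0], claiming from a single item that this worker never labels class 1 correctly; with smoothing it becomes roughly [0.37, 0.63]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def symmetric_confusion(n_classes, accuracy):\n",
    "    # Accuracy on the diagonal, the remaining mass split evenly across the other classes.\n",
    "    c = np.full((n_classes, n_classes), (1 - accuracy) / (n_classes - 1))\n",
    "    np.fill_diagonal(c, accuracy)\n",
    "    return c\n",
    "\n",
    "def smoothed_confusion(soft_counts, all_worker_prior):\n",
    "    # Same pseudo-count update as maximization_step, for a single worker.\n",
    "    n_classes = soft_counts.shape[0]\n",
    "    num_annotations = soft_counts.sum()\n",
    "    accuracy = soft_counts.diagonal().sum() / num_annotations\n",
    "    c = (soft_counts\n",
    "         + symmetric_confusion(n_classes, accuracy) * num_annotations\n",
    "         + symmetric_confusion(n_classes, all_worker_prior) * num_annotations / 2)\n",
    "    # Normalize each row (true class) into a distribution over annotated classes.\n",
    "    return c / c.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Rows: true class, columns: annotated class (hypothetical hard counts).\n",
    "demo_counts_mtrx = np.array([[3.0, 0.0],\n",
    "                             [1.0, 0.0]])\n",
    "demo_c = smoothed_confusion(demo_counts_mtrx, all_worker_prior=0.7)\n",
    "print(demo_c.round(2))"
   ]
  },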
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def most_common(labels):\n",
    "    unique_classes, class_votes = np.unique(labels, return_counts=True)\n",
    "    winning_num_votes = np.max(class_votes)\n",
    "    winning_class = unique_classes[np.where(class_votes == winning_num_votes)]\n",
    "    if len(winning_class) == 1:  # clear majority\n",
    "        return winning_class[0]\n",
    "    else:                        # break ties randomly\n",
    "        return np.random.choice(winning_class)  \n",
    "\n",
    "def majority_vote(dset_objects):\n",
    "    final_labels = []\n",
    "    for dset_object in dset_objects:\n",
    "        labels = []\n",
    "        for annotation in dset_object['annotations']:\n",
    "            label = annotation['annotationData']['content']\n",
    "            labels.append(label)\n",
    "        winner = most_common(labels)\n",
    "        final_labels.append({ 'datasetObjectId': dset_object['datasetObjectId'],\n",
    "            'consolidatedAnnotation': {'content': {'categories-metadata': {'class-name': winner}}}})\n",
    "    return final_labels\n",
    "\n",
    "def map_labels_to_raw_annotations(dset_objects):\n",
    "    raw_annotations_with_ground_truth = []\n",
    "    for dset_object in dset_objects:\n",
    "        true_label = true_labels[dset_object['datasetObjectId']]\n",
    "        for annotation in dset_object['annotations']:\n",
    "            label = annotation['annotationData']['content']\n",
    "            raw_annotations_with_ground_truth.append({\n",
    "                'Predicted Label': label,   # Single worker raw label\n",
    "                'True Label': true_label\n",
    "            })\n",
    "    return raw_annotations_with_ground_truth\n",
    "            \n",
    "def compute_accuracy(annotated_labels):\n",
    "    consolidated_annotations_with_ground_truth = []\n",
    "    num_right = 0 \n",
    "    for label in annotated_labels:\n",
    "        dset_object_id = label['datasetObjectId']\n",
    "        true_label = true_labels[dset_object_id]\n",
    "        cons_label = label['consolidatedAnnotation']['content']['categories-metadata']['class-name']\n",
    "        consolidated_annotations_with_ground_truth.append({\n",
    "            'Predicted Label': cons_label,\n",
    "            'True Label': true_label,\n",
    "            'Dataset Object ID': dset_object_id\n",
    "        })\n",
    "        if cons_label == true_label:\n",
    "            num_right += 1\n",
    "    return consolidated_annotations_with_ground_truth, num_right / len(annotated_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simulate experiments with different numbers of annotators\n",
    "\n",
    "To simulate multiple runs with different numbers of annotators, for each annotator count from one to five we draw 20 random samples of that many annotations per image from our five-annotator dataset. With all five annotations, every sample contains the same labels, so MDS shows no performance variation, but MV still varies because of random tie-breaking. Running 20 real experiments with different workers would make MDS performance vary as well; for simplicity, however, we simulate the runs from a single dataset."
   ]
  },
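  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why MV can still vary when all five annotations are used, consider an image whose votes split 2-2-1. The sketch below is a plain-Python stand-in for `most_common` (it uses `collections.Counter` and `random.choice` instead of NumPy) applied to made-up votes: a clear majority always wins, while repeated calls on the tied image can return either of the two leading labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "def majority_label(labels, rng=random):\n",
    "    # Keep every label that received the top number of votes.\n",
    "    counts = Counter(labels)\n",
    "    top = max(counts.values())\n",
    "    winners = sorted(label for label, n in counts.items() if n == top)\n",
    "    # Break ties uniformly at random, like most_common() above.\n",
    "    return rng.choice(winners)\n",
    "\n",
    "print(majority_label(['Owl', 'Owl', 'Falcon', 'Owl', 'Duck']))  # Owl\n",
    "demo_rng = random.Random(0)\n",
    "demo_tied = ['Sparrow', 'Canary', 'Sparrow', 'Canary', 'Parrot']\n",
    "demo_outcomes = {majority_label(demo_tied, demo_rng) for _ in range(50)}\n",
    "print(sorted(demo_outcomes))"
   ]
  },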
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect all the individual worker annotations\n",
    "dset_to_annotations = []\n",
    "for annot_fname in glob.glob('ic_data/worker-response/**', recursive=True):\n",
    "    if annot_fname.endswith('json'):\n",
    "        with open(annot_fname, 'r') as f:\n",
    "            annot_data = json.load(f)\n",
    "        dset_id = int(annot_fname.split('/')[3])\n",
    "        annotations = []\n",
    "        for answer in annot_data['answers']:\n",
    "            label = answer['answerContent']['crowd-image-classifier']['label'].replace(' ', '')\n",
    "            worker_id = answer['workerId']\n",
    "            annotations.append({'workerId': worker_id, 'annotationData':  {'content': label}})\n",
    "        dset_annotations = {'datasetObjectId': dset_id,'annotations': annotations}\n",
    "        dset_to_annotations.append(dset_annotations)\n",
    "\n",
    "label_attribute_name = 'categories'\n",
    "\n",
    "RAW_ANNOTATIONS_WITH_GROUND_TRUTH = map_labels_to_raw_annotations(dset_to_annotations)\n",
    "\n",
    "# Find the number of annotators by looking at the annotations.\n",
    "n_annotators = len(dset_to_annotations[0]['annotations'])\n",
    "\n",
    "dset_to_annotations_by_worker_count = {}\n",
    "for i in range(1, n_annotators + 1):\n",
    "    dset_to_annotations_by_worker_count[i] = dset_to_annotations\n",
    "    \n",
    "# Run num_iter iterations of both algorithms by sampling from the data, so we can\n",
    "# plot both the mean and the error bars for the performance\n",
    "num_iter = 20\n",
    "ds_accuracies = defaultdict(list)\n",
    "mv_accuracies = defaultdict(list)\n",
    "DS_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS = defaultdict(dict)\n",
    "MV_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS = defaultdict(dict)\n",
    "for num_workers in dset_to_annotations_by_worker_count:\n",
    "    for i in range(num_iter):\n",
    "        annotations = []\n",
    "        for dset in dset_to_annotations:\n",
    "            sample = random.sample(dset['annotations'], num_workers)\n",
    "            vote = {'datasetObjectId': dset['datasetObjectId'],'annotations': sample }\n",
    "            annotations.append(vote)\n",
    "            \n",
    "        dawid_skene = MulticlassDawidSkeneEM(jobname)\n",
    "\n",
    "        ds_labels = dawid_skene.update(annotations, categories, label_attribute_name)\n",
    "        DS_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS[num_workers][i], ds_accuracy = compute_accuracy(ds_labels)\n",
    "        MV_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS[num_workers][i], mv_accuracy = compute_accuracy(majority_vote(annotations))\n",
    "        ds_accuracies[num_workers].append(ds_accuracy)\n",
    "        mv_accuracies[num_workers].append(mv_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the performance as a function of the number of annotators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mv_errors = {k:1-np.array(v) for k,v in mv_accuracies.items()}\n",
    "mv_err_mean = {k:np.mean(v) for k,v in mv_errors.items()}\n",
    "mv_err_sem = {k:stats.sem(v) for k,v in mv_errors.items()}\n",
    "\n",
    "ds_errors = {k:1-np.array(v) for k,v in ds_accuracies.items()}\n",
    "ds_err_mean = {k:np.mean(v) for k,v in ds_errors.items()}\n",
    "ds_err_sem = {k:stats.sem(v) for k,v in ds_errors.items()}\n",
    "\n",
    "annotator_range = np.arange(1, 6)\n",
    "y_mv, y_mv_sem, y_ds, y_ds_sem = [],[],[],[]\n",
    "for k in annotator_range:\n",
    "    y_mv.append(mv_err_mean[k])\n",
    "    y_mv_sem.append(mv_err_sem[k])\n",
    "    y_ds.append(ds_err_mean[k])\n",
    "    y_ds_sem.append(ds_err_sem[k])\n",
    "\n",
    "w = 0.4\n",
    "fig= plt.figure(figsize=(6, 4), facecolor='white', dpi=100)\n",
    "plt.bar(annotator_range[1:]-w/2, y_mv[1:], width=w, yerr=y_mv_sem[1:], log=True, capsize=2, label='Majority Vote')\n",
    "plt.bar(annotator_range[1:]+w/2, y_ds[1:], width=w, yerr=y_ds_sem[1:], log=True, capsize=2, label='Modified Dawid-Skene')\n",
    "sw = np.array([y_mv[0]]*6)\n",
    "sw_sem = np.array([y_mv_sem[0]]*6)\n",
    "plt.plot(range(1,7), sw, 'k-.', label='Single Worker')\n",
    "plt.fill_between(range(1,7), sw - sw_sem, sw + sw_sem, color='r', alpha=0.3)\n",
    "fig.gca().yaxis.set_minor_formatter(NullFormatter())\n",
    "plt.xlim([1.5,5.5])\n",
    "plt.xticks(annotator_range[1:])\n",
    "plt.title('Performance Comparison')\n",
    "plt.ylabel('Classification Error')\n",
    "plt.xlabel('Number of Annotators')\n",
    "plt.legend(loc='upper right', bbox_to_anchor=(1, 0.9))\n",
    "plt.grid(which='both', alpha=0.3);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Find examples where MV and MDS disagree\n",
    "\n",
    "Majority Voting (MV) and Modified Dawid-Skene (MDS) can perform similarly, depending on the dataset and workforce, and that is the case for the dataset used here. Even so, the preceding plot shows that MDS achieves a lower classification error than MV on average. In the following cells, we search the simulated runs for one in which MDS correctly labels at least one image that MV gets wrong, and vice versa."
   ]
  },
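  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comparison below splits images into groups by which consolidation method labeled them correctly. As a minimal sketch on toy data (the object IDs and bird labels are hypothetical), the groups reduce to set intersections over the IDs that each method got right or wrong:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy consolidated output for five images (hypothetical IDs and labels).\n",
    "ids = np.array([0, 1, 2, 3, 4])\n",
    "y_true = np.array(['Owl', 'Duck', 'Owl', 'Finch', 'Duck'])\n",
    "ds_pred = np.array(['Owl', 'Duck', 'Finch', 'Finch', 'Owl'])\n",
    "mv_pred = np.array(['Finch', 'Duck', 'Owl', 'Finch', 'Owl'])\n",
    "\n",
    "def ids_where(mask):\n",
    "    # Set of dataset object IDs (as plain ints) where mask is True.\n",
    "    return {int(i) for i in ids[mask]}\n",
    "\n",
    "ds_right, ds_wrong = ids_where(y_true == ds_pred), ids_where(y_true != ds_pred)\n",
    "mv_right, mv_wrong = ids_where(y_true == mv_pred), ids_where(y_true != mv_pred)\n",
    "\n",
    "ds_right_mv_wrong = ds_right & mv_wrong  # MDS corrected MV\n",
    "mv_right_ds_wrong = mv_right & ds_wrong  # MV corrected MDS\n",
    "both_wrong = mv_wrong & ds_wrong         # neither method was right\n",
    "print(ds_right_mv_wrong, mv_right_ds_wrong, both_wrong)\n",
    "```\n",
    "\n",
    "The next cell applies the same intersections to every simulated run with 5 annotators and records how many images fall into each group."
   ]
  },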
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_ds_got_right_mv_got_wrong_list = []\n",
    "images_mv_got_right_ds_got_wrong_list = []\n",
    "worker_count = 5\n",
    "for i in range(num_iter):\n",
    "    ds_data = pd.DataFrame(data=DS_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS[worker_count][i], columns =['True Label', 'Predicted Label', 'Dataset Object ID'])\n",
    "    ds_data = ds_data.sort_values(by=['Dataset Object ID'])\n",
    "    y_true_ds = ds_data['True Label']\n",
    "    y_pred_ds = ds_data['Predicted Label']\n",
    "\n",
    "    mv_data = pd.DataFrame(data=MV_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS[worker_count][i], columns =['True Label', 'Predicted Label', 'Dataset Object ID'])\n",
    "    mv_data = mv_data.sort_values(by=['Dataset Object ID'])\n",
    "    y_true_mv = mv_data['True Label']\n",
    "    y_pred_mv = mv_data['Predicted Label']\n",
    "\n",
    "    ds_ids = ds_data['Dataset Object ID'].values\n",
    "    images_ds_got_right = ds_ids[np.where(y_true_ds == y_pred_ds)]\n",
    "    images_ds_got_wrong = ds_ids[np.where(y_true_ds != y_pred_ds)]\n",
    "\n",
    "    mv_ids = mv_data['Dataset Object ID'].values\n",
    "    images_mv_got_right = mv_ids[np.where(y_true_mv == y_pred_mv)]\n",
    "    images_mv_got_wrong = mv_ids[np.where(y_true_mv != y_pred_mv)]\n",
    "\n",
    "    images_ds_got_right_mv_got_wrong = set(images_ds_got_right) & set(images_mv_got_wrong)\n",
    "    images_mv_got_right_ds_got_wrong = set(images_mv_got_right) & set(images_ds_got_wrong)\n",
    "    images_mv_got_wrong_ds_got_wrong = set(images_mv_got_wrong) & set(images_ds_got_wrong)\n",
    "    images_ds_got_right_mv_got_wrong_list.append(len(images_ds_got_right_mv_got_wrong))\n",
    "    images_mv_got_right_ds_got_wrong_list.append(len(images_mv_got_right_ds_got_wrong))\n",
    "    print('%%%%%%%%% Run number: {} %%%%%%%%%'.format(i))\n",
    "    print('Number of images that Majority Voting got right: {}'.format(len(images_mv_got_right)))\n",
    "    print('Number of images that Modified Dawid-Skene got right: {}'.format(len(images_ds_got_right)))\n",
    "    print('Number of images that Majority Voting got wrong but Modified Dawid-Skene got right: {}'.format(len(images_ds_got_right_mv_got_wrong)))\n",
    "    print('Number of images that Modified Dawid-Skene got wrong but Majority Voting got right: {}'.format(len(images_mv_got_right_ds_got_wrong)))\n",
    "    print('Number of images that both Modified Dawid-Skene and Majority Voting got wrong: {}'.format(len(images_mv_got_wrong_ds_got_wrong)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of runs in which MDS corrected MV (or vice versa) on at least one image.\n",
    "# np.where returns a tuple of index arrays, so take element [0].\n",
    "runs_where_ds_made_corrections_to_mv = np.where(np.array(images_ds_got_right_mv_got_wrong_list) != 0)[0]\n",
    "runs_where_mv_made_corrections_to_ds = np.where(np.array(images_mv_got_right_ds_got_wrong_list) != 0)[0]\n",
    "run_where_both_made_corrections = set(runs_where_ds_made_corrections_to_mv) & set(runs_where_mv_made_corrections_to_ds)\n",
    "\n",
    "# Prefer a run where each method fixed at least one of the other's mistakes;\n",
    "# otherwise fall back to a run where only one did, then to any run.\n",
    "if len(run_where_both_made_corrections) != 0:\n",
    "    chosen_run = np.random.choice(sorted(run_where_both_made_corrections))\n",
    "elif len(runs_where_ds_made_corrections_to_mv) != 0:\n",
    "    chosen_run = np.random.choice(runs_where_ds_made_corrections_to_mv)\n",
    "elif len(runs_where_mv_made_corrections_to_ds) != 0:\n",
    "    chosen_run = np.random.choice(runs_where_mv_made_corrections_to_ds)\n",
    "else:\n",
    "    chosen_run = np.random.choice(range(num_iter))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Confusion Matrices\n",
    "Another useful way to look at performance is through confusion matrices, which show how often images of each true class receive each predicted label. A perfect predictor misclassifies nothing, so its row-normalized confusion matrix has 1s on the diagonal and 0s everywhere else. Let's plot these matrices for the chosen run."
   ]
  },
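  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `normalize=True`, the plotting helper defined in the next cell divides each row of the count matrix by that row's total, so row *i* gives the fraction of images of true class *i* assigned to each label. A minimal sketch with hypothetical counts for three classes (not taken from this dataset):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical counts: rows are true classes, columns are predicted classes.\n",
    "cm = np.array([[8, 2, 0],\n",
    "               [1, 9, 0],\n",
    "               [0, 0, 5]])\n",
    "\n",
    "# Divide each row by its total so that every row sums to 1.\n",
    "cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "print(cm_norm)  # first row is [0.8, 0.2, 0.0]\n",
    "\n",
    "# A perfect predictor puts every count on the diagonal,\n",
    "# so its row-normalized matrix is the identity.\n",
    "perfect = np.diag([10, 10, 5]).astype('float')\n",
    "perfect_norm = perfect / perfect.sum(axis=1)[:, np.newaxis]\n",
    "print(perfect_norm)\n",
    "```\n",
    "\n",
    "A class with no true examples would give a zero row total and a row of NaNs; the sketch assumes every class appears at least once among the true labels."
   ]
  },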
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', \n",
    "                          cmap=plt.cm.Blues, xlabel='Predicted Label', ylabel='True Label'):\n",
    "    \"\"\"\n",
    "    Plot the confusion matrix `cm` as an annotated heat map.\n",
    "    If `normalize=True`, each row is divided by its total, so each entry is the\n",
    "    fraction of images of that true class that received that predicted label.\n",
    "    \"\"\"\n",
    "    if normalize:\n",
    "        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "\n",
    "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    plt.title(title)\n",
    "    tick_marks = np.arange(len(classes))\n",
    "    plt.xticks(tick_marks, classes, rotation=90, fontsize=12)\n",
    "    plt.yticks(tick_marks, classes, fontsize=12)\n",
    "\n",
    "    fmt = '.2f' if normalize else 'd'\n",
    "    thresh = cm.max() / 2.\n",
    "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "        plt.text(j, i, format(cm[i, j], fmt),\n",
    "                 horizontalalignment=\"center\",\n",
    "                 color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "    plt.ylabel(ylabel, fontsize=15)\n",
    "    plt.xlabel(xlabel, fontsize=15)\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the data. Compute the confusion matrices.\n",
    "raw_data = pd.DataFrame(data=RAW_ANNOTATIONS_WITH_GROUND_TRUTH, columns =['True Label', 'Predicted Label'])\n",
    "y_true_raw = raw_data['True Label']\n",
    "y_pred_raw = raw_data['Predicted Label']\n",
    "cnf_matrix_raw = confusion_matrix(y_true_raw, y_pred_raw, labels=categories)\n",
    "\n",
    "DS_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH = DS_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS[worker_count][chosen_run]\n",
    "ds_data = pd.DataFrame(data=DS_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH, columns =['True Label', 'Predicted Label', 'Dataset Object ID'])\n",
    "ds_data = ds_data.sort_values(by=['Dataset Object ID'])\n",
    "y_true_ds = ds_data['True Label']\n",
    "y_pred_ds = ds_data['Predicted Label']\n",
    "cnf_matrix_ds = confusion_matrix(y_true_ds, y_pred_ds, labels=categories)\n",
    "\n",
    "MV_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH = MV_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH_ALL_RUNS[worker_count][chosen_run]\n",
    "mv_data = pd.DataFrame(data=MV_CONSOLIDATED_ANNOTATIONS_WITH_GROUND_TRUTH, columns =['True Label', 'Predicted Label', 'Dataset Object ID'])\n",
    "mv_data = mv_data.sort_values(by=['Dataset Object ID'])\n",
    "y_true_mv = mv_data['True Label']\n",
    "y_pred_mv = mv_data['Predicted Label']\n",
    "cnf_matrix_mv = confusion_matrix(y_true_mv, y_pred_mv, labels=categories)\n",
    "\n",
    "# Plot the confusion matrices\n",
    "plt.figure(num=None, figsize=(18, 9), dpi=80, facecolor='w', edgecolor='k')\n",
    "plt.subplot(1,3,1)\n",
    "plot_confusion_matrix(cnf_matrix_raw, classes=categories, normalize=True,\n",
    "                      xlabel='Raw Label', title='Normalized confusion matrix')\n",
    "plt.subplot(1,3,2)\n",
    "plot_confusion_matrix(cnf_matrix_ds, classes=categories, normalize=True,\n",
    "                      ylabel='', xlabel='Modified Dawid-Skene Label', title='Normalized confusion matrix')\n",
    "plt.subplot(1,3,3)\n",
    "plot_confusion_matrix(cnf_matrix_mv, classes=categories, normalize=True,\n",
    "                      ylabel='', xlabel='Majority Voting Label', title='Normalized confusion matrix')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Qualitative Comparison of Modified Dawid-Skene and Majority Voting\n",
    "Let's now look at some qualitative results. We show three sets of example images:\n",
    "\n",
    "1. Images that Majority Voting got wrong but Modified Dawid-Skene got right.\n",
    "2. Images that Modified Dawid-Skene got wrong but Majority Voting got right.\n",
    "3. Images that both methods got wrong."
   ]
  },
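  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way the first set of images (Majority Voting wrong, Modified Dawid-Skene right) can arise is when a few reliable workers effectively outweigh a larger group of less reliable ones. The sketch below illustrates this with a simplified reliability-weighted vote under a 'one-coin' worker model with known worker accuracies. It is not the Ground Truth Modified Dawid-Skene algorithm: Dawid-Skene-style methods estimate worker quality from the annotations themselves rather than taking it as given. The helper names, labels, and accuracies are all hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "def plain_majority_vote(labels):\n",
    "    # Hypothetical helper: most frequent label; ties go to the label seen first.\n",
    "    return Counter(labels).most_common(1)[0][0]\n",
    "\n",
    "def weighted_vote(labels, accuracies, n_classes):\n",
    "    # Hypothetical helper using a one-coin worker model: a worker with accuracy p\n",
    "    # gives the true class with probability p and each other class with\n",
    "    # probability (1 - p) / (n_classes - 1). Under a uniform prior, each vote for\n",
    "    # a class adds log((n_classes - 1) * p / (1 - p)) to that class's\n",
    "    # log-posterior, up to a constant shared by all classes.\n",
    "    # Only voted-for classes are scored, which yields the posterior argmax\n",
    "    # as long as every p > 1 / n_classes (all weights positive).\n",
    "    scores = {}\n",
    "    for label, p in zip(labels, accuracies):\n",
    "        weight = math.log((n_classes - 1) * p / (1 - p))\n",
    "        scores[label] = scores.get(label, 0.0) + weight\n",
    "    return max(scores, key=scores.get)\n",
    "\n",
    "# Hypothetical image: three weak workers say 'Duck', two reliable workers say 'Goose'.\n",
    "labels = ['Duck', 'Duck', 'Duck', 'Goose', 'Goose']\n",
    "accuracies = [0.4, 0.4, 0.4, 0.95, 0.95]\n",
    "print(plain_majority_vote(labels))                      # Duck\n",
    "print(weighted_vote(labels, accuracies, n_classes=5))  # Goose\n",
    "```\n",
    "\n",
    "Here each weak vote carries a weight of about 0.98 and each reliable vote about 4.33, so the two reliable workers outweigh the three weak ones even though they are in the minority."
   ]
  },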
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dset_id_to_annotations_dict = {}\n",
    "for dset_ann in dset_to_annotations:\n",
    "    dset_obj_id = dset_ann['datasetObjectId']\n",
    "    annotation_list = []\n",
    "    for anno_with_metadata in dset_ann['annotations']:\n",
    "        annotation_list.append(anno_with_metadata['annotationData']['content'])\n",
    "    dset_id_to_annotations_dict[dset_obj_id] = annotation_list\n",
    "    \n",
    "ds_ids = ds_data['Dataset Object ID'].values\n",
    "images_ds_got_right = ds_ids[np.where(y_true_ds == y_pred_ds)]\n",
    "images_ds_got_wrong = ds_ids[np.where(y_true_ds != y_pred_ds)]\n",
    "\n",
    "mv_ids = mv_data['Dataset Object ID'].values\n",
    "images_mv_got_right = mv_ids[np.where(y_true_mv == y_pred_mv)]\n",
    "images_mv_got_wrong = mv_ids[np.where(y_true_mv != y_pred_mv)]\n",
    "\n",
    "images_ds_got_right_mv_got_wrong = set(images_ds_got_right) & set(images_mv_got_wrong)\n",
    "images_mv_got_right_ds_got_wrong = set(images_mv_got_right) & set(images_ds_got_wrong)\n",
    "images_mv_got_wrong_ds_got_wrong = set(images_mv_got_wrong) & set(images_ds_got_wrong)\n",
    "\n",
    "print('Chosen Run: {}'.format(chosen_run))\n",
    "print('Number of images that Majority Voting got right: {}'.format(len(images_mv_got_right)))\n",
    "print('Number of images that Modified Dawid-Skene got right: {}'.format(len(images_ds_got_right)))\n",
    "print('Number of images that Majority Voting got wrong but Modified Dawid-Skene got right: {}'.format(len(images_ds_got_right_mv_got_wrong)))\n",
    "print('Number of images that Modified Dawid-Skene got wrong but Majority Voting got right: {}'.format(len(images_mv_got_right_ds_got_wrong)))\n",
    "print('Number of images that both Modified Dawid-Skene and Majority Voting got wrong: {}'.format(len(images_mv_got_wrong_ds_got_wrong)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BUCKET = 'open-images-dataset'\n",
    "\n",
    "def get_img_from_s3(bucket, key):\n",
    "    s3 = boto3.resource('s3')\n",
    "    s3_bucket = s3.Bucket(bucket)\n",
    "    obj = s3_bucket.Object(key=key).get()\n",
    "    img = BytesIO(obj['Body'].read())\n",
    "    img=Image.open(img)\n",
    "    return img\n",
    "\n",
    "def visualize_images(image_ids, annotations_dict, img_uris, true_labels, ds_pred, mv_pred, title, max_fig=4):\n",
    "    fig = plt.figure(figsize=(30, 8))\n",
    "    fig.suptitle(title, fontsize=15)\n",
    "    for i,img_id in enumerate(image_ids):\n",
    "        if i == max_fig: break\n",
    "        raw_annotations = annotations_dict[img_id]\n",
    "        true_label = true_labels[img_id]\n",
    "        ds_label = ds_pred[img_id]\n",
    "        mv_label = mv_pred[img_id]\n",
    "        IMG_KEY = img_uris[img_id].split('/')[-1]\n",
    "        KEY = 'validation/{}'.format(IMG_KEY)\n",
    "        image = get_img_from_s3(BUCKET, KEY)\n",
    "        ax = fig.add_subplot(1,max_fig,i+1)\n",
    "        plt.imshow(image)\n",
    "        ax.set_title('Annotations: {}\\nTrue Label: {}\\nMajority Voting Label: {}\\nModified Dawid-Skene Label: {}'.format(raw_annotations, \n",
    "                                                                            true_label, mv_label, ds_label), fontsize=12)\n",
    "        ax.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_title = \"Example images that Majority Voting got wrong but Modified Dawid-Skene got right\"\n",
    "visualize_images(list(images_ds_got_right_mv_got_wrong), dset_id_to_annotations_dict, \n",
    "                 img_uris, y_true_ds.values, y_pred_ds.values, y_pred_mv.values, fig_title)\n",
    "\n",
    "fig_title = \"Example images that Modified Dawid-Skene got wrong but Majority Voting got right\"\n",
    "visualize_images(list(images_mv_got_right_ds_got_wrong), dset_id_to_annotations_dict, \n",
    "                 img_uris, y_true_ds.values, y_pred_ds.values, y_pred_mv.values, fig_title)\n",
    "\n",
    "fig_title = \"Example images that both Modified Dawid-Skene and Majority Voting got wrong\"\n",
    "visualize_images(list(images_mv_got_wrong_ds_got_wrong), dset_id_to_annotations_dict, \n",
    "                 img_uris, y_true_ds.values, y_pred_ds.values, y_pred_mv.values, fig_title)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
